/*
 * AIEngine a new generation network intrusion detection system.
 *
 * Copyright (C) 2013-2023  Luis Campo Giralte
 *
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Library General Public
 * License as published by the Free Software Foundation; either
 * version 2 of the License, or (at your option) any later version.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Library General Public License for more details.
 *
 * You should have received a copy of the GNU Library General Public
 * License along with this library; if not, write to the
 * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
 * Boston, MA  02110-1301, USA.
 *
 * Written by Luis Campo Giralte <luis.camp0.2009@gmail.com>
 *
 */
#include "Flow.h"
#include "Protocol.h"

namespace aiengine {

void Flow::setFiveTuple(uint32_t src_a, uint16_t src_p, uint16_t proto, uint32_t dst_a, uint16_t dst_p) {

	// Transport part of the tuple
	protocol_ = proto;
	source_port_ = src_p;
	dest_port_ = dst_p;

	// Network part of the tuple: 32 bit addresses, tagged with IPPROTO_IP
	address_.setSourceAddress(src_a);
	address_.setDestinationAddress(dst_a);
	address_.setType(IPPROTO_IP);
}

void Flow::setFiveTuple6(struct in6_addr *src_a, uint16_t src_p, uint16_t proto, struct in6_addr *dst_a, uint16_t dst_p) {

	// Transport part of the tuple
	protocol_ = proto;
	source_port_ = src_p;
	dest_port_ = dst_p;

	// Network part of the tuple: 128 bit addresses, tagged with IPPROTO_IPV6
	address_.setSourceAddress6(src_a);
	address_.setDestinationAddress6(dst_a);
	address_.setType(IPPROTO_IPV6);
}

void Flow::reset() {

	hash_ = 0;
	total_bytes[static_cast<int>(FlowDirection::FORWARD)] = 0;
	total_bytes[static_cast<int>(FlowDirection::BACKWARD)] = 0;
	total_packets[static_cast<int>(FlowDirection::FORWARD)] = 0;
	total_packets[static_cast<int>(FlowDirection::BACKWARD)] = 0;
	total_packets_l7 = 0;
	address_.reset();
	source_port_ = 0;
	dest_port_ = 0;
	protocol_ = 0;
	have_tag_ = false;
	reject_ = false;
	partial_reject_ = false;
	have_evidence_ = false;
	write_matched_packet_ = false;
	tag_ = 0xFFFFFFFF;
	ipset.reset();
	forwarder.reset();

	// Reset layer4 object attach
	layer4info.reset();
	// Reset layer7 object attach
	layer7info.reset();

	// Reset frequencies objects
	frequencies.reset();
	packet_frequencies.reset();

	regex.reset();
	regex_mng.reset();
	packet = nullptr;
	frequency_engine_inspected_ = false;
	prev_direction_ = direction_ = FlowDirection::FORWARD;
	pa_ = PacketAnomalyType::NONE;
	arrive_time_ = 0;
	current_time_ = 0;
	label_.reset();
#if defined(BINDING)
        is_accept_ = true;
        is_alerted_ = false;
	l7_payload.reset();
        total_drop_packets = 0;
        total_drop_bytes = 0;
#endif
	upstream_ttl = 0;
	downstream_ttl = 0;
}

void Flow::show(std::ostream &out) const {

	// std::setprecision (used for the entropy below) changes the stream
	// precision, which is NOT part of fmtflags. Save both so the caller's
	// stream is handed back exactly as it was received.
	const std::ios_base::fmtflags saved_flags(out.flags());
	const std::streamsize saved_precision = out.precision();

	if (haveTag())
		out << "Tag:" << getTag() << " ";

	out << "TTL(" << static_cast<int>(upstream_ttl) << "," << static_cast<int>(downstream_ttl) << ")";

	if (getPacketAnomaly() != PacketAnomalyType::NONE)
		out << " Anomaly:" << getFlowAnomalyString();

	// Lock each weak reference once so the pointer checked is the one used
	if (auto ip = ipset.lock())
		out << " IPset:" << ip->name();

	if ((label_) and (label_->length() > 0))
		out << " Label:" << *label_;

	out << " ";

	showL4(out);
	showL7(out);

	if (auto re = regex.lock())
		out << " Regex:" << re->name();

#if defined(HAVE_REJECT_FLOW)
	if (isPartialReject())
		out << " Rejected";
#endif
	if (frequencies) {
		out << " Dispersion(" << frequencies->getDispersion() << ")"
			<< "Enthropy(" << std::setprecision(4) << frequencies->getEntropy() << ")"
			<< "Packets(" << frequencies->packets_inspected << ") "
			<< boost::format("%-8s") % frequencies->getFrequenciesString();
	}

	// Restore both the flags and the precision modified above
	out.flags(saved_flags);
	out.precision(saved_precision);
}

void Flow::showL4(std::ostream &out) const {

	// Non TCP flows only print their GPRS info, if any is attached
	if (protocol_ != IPPROTO_TCP) {
		if (auto gprs = getGPRSInfo())
			out << *gprs;
		return;
	}

	// TCP flows print their TCP info prefixed with "TCP:"
	if (auto tcp = getTCPInfo())
		out << "TCP:" << *tcp;
}

void Flow::showL7(std::ostream &out) const {

	// The layer 7 accessors are probed in a fixed order; only the first one
	// returning an object is printed, the rest are never called.
	if (protocol_ != IPPROTO_TCP) {
		if (auto dns = getDNSInfo())
			out << *dns;
		else if (auto sip = getSIPInfo())
			out << *sip;
		else if (auto ssdp = getSSDPInfo())
			out << *ssdp;
		else if (auto netbios = getNetbiosInfo())
			out << *netbios;
		else if (auto coap = getCoAPInfo())
			out << *coap;
		else if (auto dhcp = getDHCPInfo())
			out << *dhcp;
		else if (auto dhcpv6 = getDHCPv6Info())
			out << *dhcpv6;
		else if (auto quic = getQuicInfo())
			out << *quic;
		else if (auto dtls = getDTLSInfo())
			out << *dtls;
		return;
	}

	if (auto http = getHTTPInfo())
		out << *http;
	else if (auto ssl = getSSLInfo())
		out << *ssl;
	else if (auto smtp = getSMTPInfo())
		out << *smtp;
	else if (auto pop = getPOPInfo())
		out << *pop;
	else if (auto imap = getIMAPInfo())
		out << *imap;
	else if (auto bitcoin = getBitcoinInfo())
		out << *bitcoin;
	else if (auto mqtt = getMQTTInfo())
		out << *mqtt;
	else if (auto smb = getSMBInfo())
		out << *smb;
	else if (auto ssh = getSSHInfo())
		out << *ssh;
	else if (auto dcerpc = getDCERPCInfo())
		out << *dcerpc;
}


std::ostream& operator<< (std::ostream &out, const Flow &flow) {

	// Layout: srcaddr:srcport:proto:dstaddr:dstport
	const char *sep = ":";

	out << flow.address_.getSrcAddrDotNotation();
	out << sep << flow.getSourcePort();
	out << sep << flow.getProtocol();
	out << sep << flow.address_.getDstAddrDotNotation();
	out << sep << flow.getDestinationPort();

	return out;
}

const char* Flow::getL7ProtocolName() const {

	const char *proto_name = "None";

        if (forwarder.lock())
        	if (ProtocolPtr proto = forwarder.lock()->getProtocol(); proto)
			proto_name = proto->name();

        return proto_name;
}

#if defined(PYTHON_BINDING)
boost::python::list Flow::getPayload() const {

	// Copies the bytes of the current packet payload into a Python list.
	boost::python::list l;

	// reset() sets packet to nullptr; return an empty list instead of
	// dereferencing a null pointer.
	if (!packet)
		return l;

	const uint8_t *pkt = packet->getPayload();

	for (int i = 0; i < packet->getLength(); ++i)
		l.append(pkt[i]);

	return l;
}

void Flow::setRegexManager(const SharedPointer<RegexManager> &rm) {

	if (!rm) {
		// Only the manager is dropped; a regex already matched on this
		// flow keeps its reference.
		regex_mng.reset();
		return;
	}

	// A new manager replaces the old one and forgets the previous regex
	regex_mng = rm;
	regex.reset();
}

#elif defined(RUBY_BINDING)
VALUE Flow::getPayload() const {

	// Copies the bytes of the current packet payload into a Ruby array.
	// reset() sets packet to nullptr; return an empty array in that case
	// instead of dereferencing a null pointer.
	if (!packet)
		return rb_ary_new2(0);

	VALUE arr = rb_ary_new2(packet->getLength());
	const uint8_t *pkt = packet->getPayload();

	for (int i = 0; i < packet->getLength(); ++i)
		rb_ary_push(arr, INT2NUM((short)pkt[i]));

	return arr;
}
#elif defined(LUA_BINDING)
RawPacket& Flow::getPacket() const {

	// The caller receives a reference, so the storage has to outlive this
	// call. A function-local static is only constructed on the FIRST call,
	// so it must be refreshed every time; otherwise all later calls would
	// keep returning the payload of the very first packet ever seen.
	// NOTE(review): the storage is shared by every flow and is not reentrant.
	static RawPacket pkt(packet->getPayload(), packet->getLength());

	pkt = RawPacket(packet->getPayload(), packet->getLength());

	return pkt;
}

const char *Flow::__str__() {

	// The text is copied into a static buffer that every call overwrites;
	// the returned pointer is only valid until the next call (not reentrant).
	// Output longer than the buffer is truncated by snprintf.
	static char flowip[1024];
	std::ostringstream text;

	text << *this;
	const std::string rendered = text.str();
	snprintf(flowip, sizeof(flowip), "%s", rendered.c_str());

	return flowip;
}

#elif defined(JAVA_BINDING)
IPAbstractSet& Flow::getIPSet() const {

	// NOTE(review): no null check; callers must make sure an IPSet is
	// attached to the flow before calling this.
	return *ipset.lock();
}
#endif

void Flow::show(Json &out) const {

	// Five tuple addressing
	out["ip"]["src"] = address_.getSrcAddrDotNotation();
	out["ip"]["dst"] = address_.getDstAddrDotNotation();
	out["port"]["src"] = source_port_;
	out["port"]["dst"] = dest_port_;

	// Per direction counters
	out["upstream"]["ttl"] = upstream_ttl;
	out["upstream"]["packets"] = total_packets[static_cast<int>(FlowDirection::FORWARD)];
	out["upstream"]["bytes"] = total_bytes[static_cast<int>(FlowDirection::FORWARD)];
	out["downstream"]["ttl"] = downstream_ttl;
	out["downstream"]["packets"] = total_packets[static_cast<int>(FlowDirection::BACKWARD)];
	out["downstream"]["bytes"] = total_bytes[static_cast<int>(FlowDirection::BACKWARD)];
	out["layer7"] = getL7ProtocolName();
	out["proto"] = protocol_;

	out["reject"] = reject_;
	out["evidence"] = have_evidence_;
#if defined(BINDING)
	out["accept"] = is_accept_;
	out["drop_packets"] = total_drop_packets;
	out["drop_bytes"] = total_drop_bytes;
#endif
	if (haveTag())
		out["tag"] = getTag();

	// The anomaly is emitted as its numeric enum value
	if (getPacketAnomaly() != PacketAnomalyType::NONE)
		out["anomaly"] = static_cast<std::int8_t>(pa_);

	if ((label_) and (label_->length() > 0))
		out["label"] = label_->c_str();

	// Lock each weak reference once so the pointer checked is the one used
	if (auto ip = ipset.lock())
		out["ipset"] = ip->name();

	if (protocol_ == IPPROTO_TCP) {
		if (auto tinfo = getTCPInfo(); tinfo)
			out["tcp"] << *tinfo.get();

		// Only the first layer 7 info found is emitted
		if (auto hinfo = getHTTPInfo()) {
			out["http"] << *hinfo.get();
		} else if (auto sinfo = getSSLInfo()) {
			out["ssl"] << *sinfo.get();
		} else if (auto smtpinfo = getSMTPInfo()) {
			out["smtp"] << *smtpinfo.get();
		} else if (auto popinfo = getPOPInfo()) {
			out["pop"] << *popinfo.get();
		} else if (auto iinfo = getIMAPInfo()) {
			out["imap"] << *iinfo.get();
		} else if (auto binfo = getBitcoinInfo()) {
			out["bitcoin"] << *binfo.get();
		} else if (auto minfo = getMQTTInfo()) {
			out["mqtt"] << *minfo.get();
		} else if (auto smbinfo = getSMBInfo()) {
			out["smb"] << *smbinfo.get();
		} else if (auto sshinfo = getSSHInfo()) {
			out["ssh"] << *sshinfo.get();
		} else if (auto dinfo = getDCERPCInfo()) {
			out["dcerpc"] << *dinfo.get();
		}
	} else {
		if (auto ginfo = getGPRSInfo(); ginfo)
			out["gprs"] << *ginfo.get();

		// Only the first layer 7 info found is emitted
		if (auto dnsinfo = getDNSInfo()) {
			out["dns"] << *dnsinfo.get();
		} else if (auto sipinfo = getSIPInfo()) {
			out["sip"] << *sipinfo.get();
		} else if (auto ssdpinfo = getSSDPInfo()) {
			out["ssdp"] << *ssdpinfo.get();
		} else if (auto nbinfo = getNetbiosInfo()) {
			out["netbios"] << *nbinfo.get();
		} else if (auto coapinfo = getCoAPInfo()) {
			out["coap"] << *coapinfo.get();
		} else if (auto dhcpinfo = getDHCPInfo()) {
			out["dhcp"] << *dhcpinfo.get();
		} else if (auto dhcpv6info = getDHCPv6Info()) {
			out["dhcpv6"] << *dhcpv6info.get();
		} else if (auto qinfo = getQuicInfo()) {
			out["quic"] << *qinfo.get();
		} else if (auto dtlsinfo = getDTLSInfo()) {
			out["dtls"] << *dtlsinfo.get();
		}
	}

	if (auto re = regex.lock()) {
		out["matchs"] = re->name();
#if defined(BINDING)
		if (write_matched_packet_) {
			// The current packet is the one that produced the match, so its
			// payload is dumped. reset() sets packet to nullptr, so guard
			// against a flow without a packet attached.
			if (packet) {
				const uint8_t *payload = packet->getPayload();
				const auto length = packet->getLength();
				std::vector<uint8_t> pkt;

				if (length > 0)
					pkt.assign(payload, payload + length);

				out["l7_payload"] = pkt;
			}
			// One shot flag: the payload is only written once per request
			write_matched_packet_ = false;
		}
#endif
	}

#if defined(HAVE_REJECT_FLOW)
	if (isPartialReject())
		out["rejected"] = true;
#endif
	if (frequencies) {
		out["dispersion"] = frequencies->getDispersion();
		out["enthropy"] = frequencies->getEntropy();
		out["frequencies"] = frequencies->getFrequenciesString();
	}
}

#if defined(BINDING)
void Flow::detach() {

        if (forwarder.lock()) {
        	if (ProtocolPtr proto = forwarder.lock()->getProtocol(); proto)
			proto->releaseFlowInfo(this);

		layer7info.reset();
		forwarder.reset();
	}
}

const char* Flow::getDirection() const {

	// FORWARD maps to "upstream"; any other direction is "downstream"
	return (direction_ == FlowDirection::FORWARD) ? "upstream" : "downstream";
}

#endif

const char* Flow::status(std::time_t now, int timeout) const {

	// Time elapsed since the last packet, in the units of getLastPacketTime()
	const std::time_t idle = now - getLastPacketTime();

	if (idle <= timeout)
		return "active";

	// Idle past twice the timeout is reported as "comatose"
	return (idle > timeout * 2) ? "comatose" : "timeout";
}

} // namespace aiengine

